{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchtext\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from torchtext.vocab import Vectors\n",
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fields: tokenized review text (include_lengths=True makes batch.text a\n",
    "# (token_ids, lengths) pair) and a single non-sequential sentiment label.\n",
    "# NOTE(review): in legacy torchtext a non-sequential Field still gets an '<unk>'\n",
    "# vocab entry by default, so label ids may start at 1 - confirm before using them as targets.\n",
    "text = torchtext.data.Field(include_lengths=True)\n",
    "label = torchtext.data.Field(sequential=False)\n",
    "\n",
    "# Drop 'neutral' examples so the task becomes binary positive/negative classification.\n",
    "train, val, test = torchtext.datasets.SST.splits(\n",
    "    text, label, filter_pred=lambda ex: ex.label != 'neutral')\n",
    "\n",
    "# Vocabularies are built from the training split only.\n",
    "text.build_vocab(train)\n",
    "label.build_vocab(train)\n",
    "\n",
    "BATCH_SIZE = 10\n",
    "# device=-1 presumably selects CPU in this legacy torchtext API (TODO confirm for the installed version);\n",
    "# repeat=False so iterating over a split yields a single pass instead of looping forever.\n",
    "train_iter, val_iter, test_iter = torchtext.data.BucketIterator.splits(\n",
    "    (train, val, test), batch_size=BATCH_SIZE, device=-1, repeat=False)\n",
    "\n",
    "# Attach pretrained fastText Simple-English-Wikipedia vectors (fetched from `url`) to the text vocab.\n",
    "url = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec'\n",
    "text.vocab.load_vectors(vectors=Vectors('wiki.simple.vec', url=url))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the first batch from the training iterator to inspect its structure.\n",
    "batch = next(iter(train_iter))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Variable containing:\n",
       "    132   3096     29   3112     14     14     65   3947   5485   8702\n",
       "   1065     10     10   3073   2210     16     13    642   3114   8719\n",
       "     57  15812   4634     10   7406      9     10      9      5     10\n",
       "     48    348     12   3943     28  11715    125      7     46   1824\n",
       "   4619    173      7      5   2568     31      8     18      6   2996\n",
       "    245      9   7138   2963   1650     23    246     11     45      6\n",
       "      9    145    142      3      5  12660     12     90    203    436\n",
       "   1971     19     11     40   1531   1343   1332     13     32   2442\n",
       "    977      7  13958    126   1680    136    288    667     15    843\n",
       "     15   2613   6581    117      3      6    119     20    119      7\n",
       "      7  11028    973    292   2680  10131     79      4     15   4167\n",
       "  10680   2156   6581      3    358   2009     10    597     45   2251\n",
       "     20      6      6   4574    148      3    266      8   5917      8\n",
       "      4   2582   5094      8   4633   4475      8   3344      3      7\n",
       "   1372   2795   1151      4     28     31   1206    138   3831    547\n",
       "      6     11   1172    530     17   3831      3     20   8259     52\n",
       "    737   3551   5757      6      7   5490     19      7     32     17\n",
       "   6644     35   3749      4   6602     15   1765    186    726    160\n",
       "     12     15    217    995  12244      4      8    174      4      7\n",
       "  14499     71      5      6      3   7294   1335     49   1002   1857\n",
       "      5     88      7    343     77      5      4     85   6317      3\n",
       "   2596    190   7150     10  12698     55   2648    818      6   3383\n",
       "    780    787    400    121  15040  10928      6    952   2471   2684\n",
       "     28    233     11     11      3   1186     22      8     10      6\n",
       "   2359     40    608   1687      5   5605    747      4   4751    532\n",
       "     19   2119    226      8      7   3127      5   5968      3     11\n",
       "   1959   8041    313   1421   4448     43   1042    961      5     13\n",
       "   1097   2169      6      4   3343    170      6   3221     13   1086\n",
       "    359     32      4   6372      6    141     22  14500     10     15\n",
       "      6      7  13883      6  12597    210    136      5     96    192\n",
       "   2482  11554    758    672     10     12    205     37     30   1679\n",
       "    977    109  10078  10745   4595      4     25    157      4   5999\n",
       "      3  12763      6     31    838    123    189     15    185      8\n",
       "      7     31     51  15680      8     16      7     71     12      4\n",
       "   5892     22   1604    708   1051      5    357     94     74  12031\n",
       "   1387    117     10  12045    187   1529   1123     49    434      6\n",
       "     12    517     50     17      7   5676     17    283      6      4\n",
       "      4  14760    107      7   4026      6      7     33     14   1409\n",
       "   7234      2     86   2843     49      4    186   1689   2515   3089\n",
       "      1      1      2    278    147     86    970    275    221    250\n",
       "      1      1      1      2      2      2      2      2      2      2\n",
       " [torch.LongTensor of size 41x10], \n",
       "  39\n",
       "  39\n",
       "  40\n",
       "  41\n",
       "  41\n",
       "  41\n",
       "  41\n",
       "  41\n",
       "  41\n",
       "  41\n",
       " [torch.LongTensor of size 10])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# With include_lengths=True this is a (token_ids, lengths) pair: token_ids is laid out\n",
    "# (max_seq_len, batch_size) - e.g. 41x10 in the saved output - and shorter sequences\n",
    "# are padded with index 1 ('<pad>'); lengths holds one true length per example.\n",
    "batch.text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "defaultdict(<function torchtext.vocab._default_unk_index>,\n",
       "            {'<unk>': 0,\n",
       "             '<pad>': 1,\n",
       "             '.': 2,\n",
       "             ',': 3,\n",
       "             'the': 4,\n",
       "             'and': 5,\n",
       "             'of': 6,\n",
       "             'a': 7,\n",
       "             'to': 8,\n",
       "             'is': 9,\n",
       "             \"'s\": 10,\n",
       "             'that': 11,\n",
       "             'in': 12,\n",
       "             'it': 13,\n",
       "             'The': 14,\n",
       "             'as': 15,\n",
       "             'film': 16,\n",
       "             'with': 17,\n",
       "             'movie': 18,\n",
       "             'but': 19,\n",
       "             'for': 20,\n",
       "             'A': 21,\n",
       "             'its': 22,\n",
       "             'an': 23,\n",
       "             'this': 24,\n",
       "             'you': 25,\n",
       "             \"n't\": 26,\n",
       "             'be': 27,\n",
       "             '...': 28,\n",
       "             'It': 29,\n",
       "             'on': 30,\n",
       "             'by': 31,\n",
       "             '--': 32,\n",
       "             'has': 33,\n",
       "             'are': 34,\n",
       "             'about': 35,\n",
       "             'more': 36,\n",
       "             'not': 37,\n",
       "             'than': 38,\n",
       "             'at': 39,\n",
       "             'from': 40,\n",
       "             'one': 41,\n",
       "             'have': 42,\n",
       "             'I': 43,\n",
       "             'like': 44,\n",
       "             'his': 45,\n",
       "             'all': 46,\n",
       "             \"'\": 47,\n",
       "             'so': 48,\n",
       "             'or': 49,\n",
       "             '-RRB-': 50,\n",
       "             '-LRB-': 51,\n",
       "             'story': 52,\n",
       "             'who': 53,\n",
       "             'into': 54,\n",
       "             'most': 55,\n",
       "             'out': 56,\n",
       "             'does': 57,\n",
       "             'too': 58,\n",
       "             'up': 59,\n",
       "             'just': 60,\n",
       "             'This': 61,\n",
       "             '``': 62,\n",
       "             'comedy': 63,\n",
       "             \"''\": 64,\n",
       "             '`': 65,\n",
       "             'good': 66,\n",
       "             'will': 67,\n",
       "             'can': 68,\n",
       "             'characters': 69,\n",
       "             'much': 70,\n",
       "             'if': 71,\n",
       "             'even': 72,\n",
       "             'no': 73,\n",
       "             'their': 74,\n",
       "             'funny': 75,\n",
       "             'An': 76,\n",
       "             'some': 77,\n",
       "             'time': 78,\n",
       "             'what': 79,\n",
       "             'way': 80,\n",
       "             'only': 81,\n",
       "             'little': 82,\n",
       "             'your': 83,\n",
       "             'which': 84,\n",
       "             'make': 85,\n",
       "             'work': 86,\n",
       "             'been': 87,\n",
       "             'they': 88,\n",
       "             'would': 89,\n",
       "             'makes': 90,\n",
       "             'never': 91,\n",
       "             'very': 92,\n",
       "             'any': 93,\n",
       "             'he': 94,\n",
       "             'us': 95,\n",
       "             'there': 96,\n",
       "             'enough': 97,\n",
       "             'bad': 98,\n",
       "             'do': 99,\n",
       "             'director': 100,\n",
       "             'was': 101,\n",
       "             'life': 102,\n",
       "             'may': 103,\n",
       "             'movies': 104,\n",
       "             'we': 105,\n",
       "             'through': 106,\n",
       "             'best': 107,\n",
       "             'love': 108,\n",
       "             'drama': 109,\n",
       "             'could': 110,\n",
       "             'There': 111,\n",
       "             ':': 112,\n",
       "             'made': 113,\n",
       "             'something': 114,\n",
       "             'really': 115,\n",
       "             'If': 116,\n",
       "             'own': 117,\n",
       "             'performances': 118,\n",
       "             'well': 119,\n",
       "             'plot': 120,\n",
       "             'films': 121,\n",
       "             'action': 122,\n",
       "             'many': 123,\n",
       "             'should': 124,\n",
       "             'better': 125,\n",
       "             'her': 126,\n",
       "             'when': 127,\n",
       "             'how': 128,\n",
       "             'people': 129,\n",
       "             'without': 130,\n",
       "             'other': 131,\n",
       "             'What': 132,\n",
       "             \"'re\": 133,\n",
       "             '?': 134,\n",
       "             'look': 135,\n",
       "             'cast': 136,\n",
       "             'also': 137,\n",
       "             'off': 138,\n",
       "             'see': 139,\n",
       "             'humor': 140,\n",
       "             'ever': 141,\n",
       "             'script': 142,\n",
       "             'sense': 143,\n",
       "             'both': 144,\n",
       "             'nothing': 145,\n",
       "             'still': 146,\n",
       "             'two': 147,\n",
       "             'every': 148,\n",
       "             'fun': 149,\n",
       "             'new': 150,\n",
       "             'audience': 151,\n",
       "             'them': 152,\n",
       "             'those': 153,\n",
       "             'character': 154,\n",
       "             'great': 155,\n",
       "             'might': 156,\n",
       "             'feel': 157,\n",
       "             'long': 158,\n",
       "             'kind': 159,\n",
       "             'such': 160,\n",
       "             '-': 161,\n",
       "             'because': 162,\n",
       "             'first': 163,\n",
       "             'performance': 164,\n",
       "             'being': 165,\n",
       "             'get': 166,\n",
       "             'often': 167,\n",
       "             'One': 168,\n",
       "             'entertaining': 169,\n",
       "             \"'ve\": 170,\n",
       "             'As': 171,\n",
       "             'In': 172,\n",
       "             'here': 173,\n",
       "             'minutes': 174,\n",
       "             'seems': 175,\n",
       "             'tale': 176,\n",
       "             'real': 177,\n",
       "             'But': 178,\n",
       "             'Hollywood': 179,\n",
       "             'between': 180,\n",
       "             'documentary': 181,\n",
       "             ';': 182,\n",
       "             'thriller': 183,\n",
       "             'hard': 184,\n",
       "             'screen': 185,\n",
       "             'few': 186,\n",
       "             'over': 187,\n",
       "             'acting': 188,\n",
       "             'down': 189,\n",
       "             'were': 190,\n",
       "             'heart': 191,\n",
       "             'another': 192,\n",
       "             'feels': 193,\n",
       "             'picture': 194,\n",
       "             'while': 195,\n",
       "             'world': 196,\n",
       "             \"'ll\": 197,\n",
       "             'almost': 198,\n",
       "             'end': 199,\n",
       "             'less': 200,\n",
       "             'comes': 201,\n",
       "             'quite': 202,\n",
       "             'actors': 203,\n",
       "             'itself': 204,\n",
       "             'take': 205,\n",
       "             'year': 206,\n",
       "             'big': 207,\n",
       "             'come': 208,\n",
       "             'interesting': 209,\n",
       "             'seen': 210,\n",
       "             'these': 211,\n",
       "             'family': 212,\n",
       "             'had': 213,\n",
       "             'romantic': 214,\n",
       "             'things': 215,\n",
       "             'before': 216,\n",
       "             'dialogue': 217,\n",
       "             'material': 218,\n",
       "             'moments': 219,\n",
       "             'rather': 220,\n",
       "             'American': 221,\n",
       "             'You': 222,\n",
       "             'watch': 223,\n",
       "             'Like': 224,\n",
       "             'ca': 225,\n",
       "             'far': 226,\n",
       "             'seem': 227,\n",
       "             'works': 228,\n",
       "             '!': 229,\n",
       "             'find': 230,\n",
       "             'thing': 231,\n",
       "             'after': 232,\n",
       "             'back': 233,\n",
       "             'human': 234,\n",
       "             'me': 235,\n",
       "             'scenes': 236,\n",
       "             'watching': 237,\n",
       "             'years': 238,\n",
       "             'yet': 239,\n",
       "             'While': 240,\n",
       "             'lot': 241,\n",
       "             'making': 242,\n",
       "             'original': 243,\n",
       "             'ultimately': 244,\n",
       "             'compelling': 245,\n",
       "             'go': 246,\n",
       "             'least': 247,\n",
       "             'right': 248,\n",
       "             'worth': 249,\n",
       "             'cinema': 250,\n",
       "             'fascinating': 251,\n",
       "             'old': 252,\n",
       "             'where': 253,\n",
       "             'man': 254,\n",
       "             'With': 255,\n",
       "             'direction': 256,\n",
       "             'once': 257,\n",
       "             'young': 258,\n",
       "             'gives': 259,\n",
       "             'piece': 260,\n",
       "             'special': 261,\n",
       "             'takes': 262,\n",
       "             'then': 263,\n",
       "             'gets': 264,\n",
       "             'give': 265,\n",
       "             'going': 266,\n",
       "             'keep': 267,\n",
       "             'moving': 268,\n",
       "             'For': 269,\n",
       "             'experience': 270,\n",
       "             'music': 271,\n",
       "             'our': 272,\n",
       "             'subject': 273,\n",
       "             'times': 274,\n",
       "             'anything': 275,\n",
       "             'bit': 276,\n",
       "             'part': 277,\n",
       "             'style': 278,\n",
       "             'think': 279,\n",
       "             'comic': 280,\n",
       "             'history': 281,\n",
       "             'say': 282,\n",
       "             'she': 283,\n",
       "             'Mr.': 284,\n",
       "             'did': 285,\n",
       "             'dull': 286,\n",
       "             'emotional': 287,\n",
       "             'full': 288,\n",
       "             'him': 289,\n",
       "             'laughs': 290,\n",
       "             'same': 291,\n",
       "             'screenplay': 292,\n",
       "             'visual': 293,\n",
       "             'why': 294,\n",
       "             'art': 295,\n",
       "             'Not': 296,\n",
       "             'know': 297,\n",
       "             'point': 298,\n",
       "             'since': 299,\n",
       "             'sometimes': 300,\n",
       "             'though': 301,\n",
       "             'together': 302,\n",
       "             'actually': 303,\n",
       "             'again': 304,\n",
       "             'away': 305,\n",
       "             'entertainment': 306,\n",
       "             'filmmakers': 307,\n",
       "             'flick': 308,\n",
       "             'genre': 309,\n",
       "             'idea': 310,\n",
       "             'need': 311,\n",
       "             'offers': 312,\n",
       "             'short': 313,\n",
       "             'want': 314,\n",
       "             'Even': 315,\n",
       "             'cinematic': 316,\n",
       "             'dark': 317,\n",
       "             'narrative': 318,\n",
       "             'series': 319,\n",
       "             'show': 320,\n",
       "             'whose': 321,\n",
       "             'All': 322,\n",
       "             'care': 323,\n",
       "             'clever': 324,\n",
       "             'kids': 325,\n",
       "             'whole': 326,\n",
       "             'engaging': 327,\n",
       "             'exercise': 328,\n",
       "             'fans': 329,\n",
       "             'goes': 330,\n",
       "             'manages': 331,\n",
       "             'study': 332,\n",
       "             'title': 333,\n",
       "             'worst': 334,\n",
       "             \"'d\": 335,\n",
       "             'around': 336,\n",
       "             'charm': 337,\n",
       "             'enjoyable': 338,\n",
       "             'feature': 339,\n",
       "             'premise': 340,\n",
       "             'probably': 341,\n",
       "             'simply': 342,\n",
       "             'women': 343,\n",
       "             'New': 344,\n",
       "             'anyone': 345,\n",
       "             'children': 346,\n",
       "             'done': 347,\n",
       "             'effort': 348,\n",
       "             'filmmaking': 349,\n",
       "             'last': 350,\n",
       "             'matter': 351,\n",
       "             'my': 352,\n",
       "             'place': 353,\n",
       "             'predictable': 354,\n",
       "             'smart': 355,\n",
       "             'And': 356,\n",
       "             'familiar': 357,\n",
       "             'nearly': 358,\n",
       "             'portrait': 359,\n",
       "             'surprisingly': 360,\n",
       "             'sweet': 361,\n",
       "             'three': 362,\n",
       "             'Despite': 363,\n",
       "             'Though': 364,\n",
       "             'always': 365,\n",
       "             'becomes': 366,\n",
       "             'energy': 367,\n",
       "             'feeling': 368,\n",
       "             'horror': 369,\n",
       "             'powerful': 370,\n",
       "             'amusing': 371,\n",
       "             'beautiful': 372,\n",
       "             'pretty': 373,\n",
       "             'set': 374,\n",
       "             'wo': 375,\n",
       "             'charming': 376,\n",
       "             'effects': 377,\n",
       "             'face': 378,\n",
       "             'looking': 379,\n",
       "             'quirky': 380,\n",
       "             'romance': 381,\n",
       "             'silly': 382,\n",
       "             'strong': 383,\n",
       "             'true': 384,\n",
       "             'trying': 385,\n",
       "             'wit': 386,\n",
       "             \"'m\": 387,\n",
       "             'French': 388,\n",
       "             'No': 389,\n",
       "             'deeply': 390,\n",
       "             'enjoy': 391,\n",
       "             'fact': 392,\n",
       "             'looks': 393,\n",
       "             'modern': 394,\n",
       "             'plays': 395,\n",
       "             'power': 396,\n",
       "             'rare': 397,\n",
       "             'solid': 398,\n",
       "             'sure': 399,\n",
       "             'tone': 400,\n",
       "             'turns': 401,\n",
       "             'under': 402,\n",
       "             'video': 403,\n",
       "             'beautifully': 404,\n",
       "             'certainly': 405,\n",
       "             'dramatic': 406,\n",
       "             'easy': 407,\n",
       "             'especially': 408,\n",
       "             'half': 409,\n",
       "             'summer': 410,\n",
       "             'believe': 411,\n",
       "             'culture': 412,\n",
       "             'debut': 413,\n",
       "             'intelligent': 414,\n",
       "             'likely': 415,\n",
       "             'put': 416,\n",
       "             'reason': 417,\n",
       "             'recent': 418,\n",
       "             'star': 419,\n",
       "             'theater': 420,\n",
       "             'John': 421,\n",
       "             'everything': 422,\n",
       "             'fine': 423,\n",
       "             'ideas': 424,\n",
       "             'intelligence': 425,\n",
       "             'interest': 426,\n",
       "             'level': 427,\n",
       "             'mess': 428,\n",
       "             'mind': 429,\n",
       "             'now': 430,\n",
       "             'small': 431,\n",
       "             'sort': 432,\n",
       "             'stuff': 433,\n",
       "             'version': 434,\n",
       "             'Its': 435,\n",
       "             'My': 436,\n",
       "             'That': 437,\n",
       "             'above': 438,\n",
       "             'along': 439,\n",
       "             'camera': 440,\n",
       "             'completely': 441,\n",
       "             'directed': 442,\n",
       "             'each': 443,\n",
       "             'else': 444,\n",
       "             'everyone': 445,\n",
       "             'fresh': 446,\n",
       "             'leave': 447,\n",
       "             'must': 448,\n",
       "             'nor': 449,\n",
       "             'problem': 450,\n",
       "             'proves': 451,\n",
       "             'ride': 452,\n",
       "             'stories': 453,\n",
       "             'suspense': 454,\n",
       "             'truly': 455,\n",
       "             'turn': 456,\n",
       "             'TV': 457,\n",
       "             'adventure': 458,\n",
       "             'already': 459,\n",
       "             'audiences': 460,\n",
       "             'boring': 461,\n",
       "             'ending': 462,\n",
       "             'filmmaker': 463,\n",
       "             'flat': 464,\n",
       "             'high': 465,\n",
       "             'hour': 466,\n",
       "             'lack': 467,\n",
       "             'lacks': 468,\n",
       "             'lives': 469,\n",
       "             'obvious': 470,\n",
       "             'pleasure': 471,\n",
       "             'sad': 472,\n",
       "             'satisfying': 473,\n",
       "             'serious': 474,\n",
       "             'shows': 475,\n",
       "             'spirit': 476,\n",
       "             'storytelling': 477,\n",
       "             'touching': 478,\n",
       "             'Just': 479,\n",
       "             'Michael': 480,\n",
       "             'against': 481,\n",
       "             'attempt': 482,\n",
       "             'beyond': 483,\n",
       "             'either': 484,\n",
       "             'hilarious': 485,\n",
       "             'instead': 486,\n",
       "             'opera': 487,\n",
       "             'perfect': 488,\n",
       "             'play': 489,\n",
       "             'war': 490,\n",
       "             'About': 491,\n",
       "             'Although': 492,\n",
       "             'At': 493,\n",
       "             'Too': 494,\n",
       "             'actor': 495,\n",
       "             'classic': 496,\n",
       "             'complex': 497,\n",
       "             'fails': 498,\n",
       "             'himself': 499,\n",
       "             'hours': 500,\n",
       "             'left': 501,\n",
       "             'light': 502,\n",
       "             'line': 503,\n",
       "             'melodrama': 504,\n",
       "             'particularly': 505,\n",
       "             'rich': 506,\n",
       "             'satire': 507,\n",
       "             'sequel': 508,\n",
       "             'shot': 509,\n",
       "             'By': 510,\n",
       "             'We': 511,\n",
       "             'become': 512,\n",
       "             'book': 513,\n",
       "             'easily': 514,\n",
       "             'neither': 515,\n",
       "             'past': 516,\n",
       "             'pretentious': 517,\n",
       "             'production': 518,\n",
       "             'social': 519,\n",
       "             'terrific': 520,\n",
       "             'tries': 521,\n",
       "             'viewers': 522,\n",
       "             'ways': 523,\n",
       "             'woman': 524,\n",
       "             'Director': 525,\n",
       "             'animation': 526,\n",
       "             'day': 527,\n",
       "             'different': 528,\n",
       "             'formula': 529,\n",
       "             'head': 530,\n",
       "             'honest': 531,\n",
       "             'images': 532,\n",
       "             'imagination': 533,\n",
       "             'jokes': 534,\n",
       "             'lost': 535,\n",
       "             'message': 536,\n",
       "             'remarkable': 537,\n",
       "             'role': 538,\n",
       "             'scene': 539,\n",
       "             'seeing': 540,\n",
       "             'slow': 541,\n",
       "             'wonderful': 542,\n",
       "             'written': 543,\n",
       "             'To': 544,\n",
       "             'bland': 545,\n",
       "             'cliches': 546,\n",
       "             'coming-of-age': 547,\n",
       "             'engrossing': 548,\n",
       "             'found': 549,\n",
       "             'got': 550,\n",
       "             'impossible': 551,\n",
       "             'inside': 552,\n",
       "             'intriguing': 553,\n",
       "             'leaves': 554,\n",
       "             'mood': 555,\n",
       "             'mystery': 556,\n",
       "             'psychological': 557,\n",
       "             'reality': 558,\n",
       "             'remains': 559,\n",
       "             'running': 560,\n",
       "             'sequences': 561,\n",
       "             'simple': 562,\n",
       "             'themselves': 563,\n",
       "             'thoroughly': 564,\n",
       "             'thoughtful': 565,\n",
       "             'tired': 566,\n",
       "             'writing': 567,\n",
       "             'Love': 568,\n",
       "             'Time': 569,\n",
       "             'When': 570,\n",
       "             'brilliant': 571,\n",
       "             'comedies': 572,\n",
       "             'crime': 573,\n",
       "             'delivers': 574,\n",
       "             'despite': 575,\n",
       "             'dumb': 576,\n",
       "             'emotionally': 577,\n",
       "             'ends': 578,\n",
       "             'events': 579,\n",
       "             'eyes': 580,\n",
       "             'fairly': 581,\n",
       "             'finally': 582,\n",
       "             'having': 583,\n",
       "             'help': 584,\n",
       "             'hero': 585,\n",
       "             'historical': 586,\n",
       "             'job': 587,\n",
       "             'journey': 588,\n",
       "             'live': 589,\n",
       "             'memorable': 590,\n",
       "             'men': 591,\n",
       "             'passion': 592,\n",
       "             'project': 593,\n",
       "             'soap': 594,\n",
       "             'surprising': 595,\n",
       "             'told': 596,\n",
       "             'viewer': 597,\n",
       "             'violence': 598,\n",
       "             'vision': 599,\n",
       "             'wrong': 600,\n",
       "             'Robert': 601,\n",
       "             'case': 602,\n",
       "             'change': 603,\n",
       "             'cold': 604,\n",
       "             'delightful': 605,\n",
       "             'entirely': 606,\n",
       "             'excellent': 607,\n",
       "             'falls': 608,\n",
       "             'gags': 609,\n",
       "             'given': 610,\n",
       "             'impressive': 611,\n",
       "             'moment': 612,\n",
       "             'otherwise': 613,\n",
       "             'political': 614,\n",
       "             'talent': 615,\n",
       "             'teen': 616,\n",
       "             'tragedy': 617,\n",
       "             'uses': 618,\n",
       "             'usual': 619,\n",
       "             'whether': 620,\n",
       "             'After': 621,\n",
       "             'Disney': 622,\n",
       "             'He': 623,\n",
       "             'Spielberg': 624,\n",
       "             'able': 625,\n",
       "             'adults': 626,\n",
       "             'air': 627,\n",
       "             'captures': 628,\n",
       "             'contrived': 629,\n",
       "             'days': 630,\n",
       "             'death': 631,\n",
       "             'flaws': 632,\n",
       "             'imagine': 633,\n",
       "             'latest': 634,\n",
       "             'mostly': 635,\n",
       "             'next': 636,\n",
       "             'stupid': 637,\n",
       "             'unsettling': 638,\n",
       "             '2': 639,\n",
       "             'David': 640,\n",
       "             'Has': 641,\n",
       "             'II': 642,\n",
       "             'Nothing': 643,\n",
       "             'acted': 644,\n",
       "             'amount': 645,\n",
       "             'animated': 646,\n",
       "             'appeal': 647,\n",
       "             'appealing': 648,\n",
       "             'attention': 649,\n",
       "             'barely': 650,\n",
       "             'concept': 651,\n",
       "             'deep': 652,\n",
       "             'during': 653,\n",
       "             'entire': 654,\n",
       "             'fantasy': 655,\n",
       "             'gentle': 656,\n",
       "             'getting': 657,\n",
       "             'gone': 658,\n",
       "             'home': 659,\n",
       "             'lead': 660,\n",
       "             'low': 661,\n",
       "             'offer': 662,\n",
       "             'old-fashioned': 663,\n",
       "             'parents': 664,\n",
       "             'period': 665,\n",
       "             'personal': 666,\n",
       "             'possible': 667,\n",
       "             'relationships': 668,\n",
       "             'remake': 669,\n",
       "             'scary': 670,\n",
       "             'sci-fi': 671,\n",
       "             'sentimental': 672,\n",
       "             'sex': 673,\n",
       "             'sit': 674,\n",
       "             'talented': 675,\n",
       "             'tedious': 676,\n",
       "             'tell': 677,\n",
       "             'themes': 678,\n",
       "             'thin': 679,\n",
       "             'try': 680,\n",
       "             'ugly': 681,\n",
       "             'welcome': 682,\n",
       "             'words': 683,\n",
       "             'Allen': 684,\n",
       "             'Big': 685,\n",
       "             'De': 686,\n",
       "             'Full': 687,\n",
       "             'Grant': 688,\n",
       "             'Is': 689,\n",
       "             'More': 690,\n",
       "             'So': 691,\n",
       "             'Some': 692,\n",
       "             'They': 693,\n",
       "             'York': 694,\n",
       "             'across': 695,\n",
       "             'act': 696,\n",
       "             'adaptation': 697,\n",
       "             'ambitious': 698,\n",
       "             'artist': 699,\n",
       "             'awful': 700,\n",
       "             'bring': 701,\n",
       "             'close': 702,\n",
       "             'creepy': 703,\n",
       "             'depth': 704,\n",
       "             'elements': 705,\n",
       "             'examination': 706,\n",
       "             'felt': 707,\n",
       "             'female': 708,\n",
       "             'final': 709,\n",
       "             'hit': 710,\n",
       "             'laugh': 711,\n",
       "             'magic': 712,\n",
       "             'nice': 713,\n",
       "             'occasionally': 714,\n",
       "             'odd': 715,\n",
       "             'overall': 716,\n",
       "             'plenty': 717,\n",
       "             'quality': 718,\n",
       "             'rarely': 719,\n",
       "             'sharp': 720,\n",
       "             'side': 721,\n",
       "             'someone': 722,\n",
       "             'straight': 723,\n",
       "             'surprise': 724,\n",
       "             'thought': 725,\n",
       "             'understand': 726,\n",
       "             'urban': 727,\n",
       "             'view': 728,\n",
       "             'visually': 729,\n",
       "             'Chan': 730,\n",
       "             'Every': 731,\n",
       "             'Jackson': 732,\n",
       "             'Thing': 733,\n",
       "             'War': 734,\n",
       "             'add': 735,\n",
       "             'approach': 736,\n",
       "             'black': 737,\n",
       "             'casting': 738,\n",
       "             'clear': 739,\n",
       "             'clichés': 740,\n",
       "             'crafted': 741,\n",
       "             'creative': 742,\n",
       "             'difficult': 743,\n",
       "             'emotions': 744,\n",
       "             'ensemble': 745,\n",
       "             'epic': 746,\n",
       "             'execution': 747,\n",
       "             'four': 748,\n",
       "             'frame': 749,\n",
       "             'future': 750,\n",
       "             'game': 751,\n",
       "             'genuine': 752,\n",
       "             'gorgeous': 753,\n",
       "             'issues': 754,\n",
       "             'keeps': 755,\n",
       "             'knows': 756,\n",
       "             'lacking': 757,\n",
       "             'moral': 758,\n",
       "             'pictures': 759,\n",
       "             'pieces': 760,\n",
       "             'promise': 761,\n",
       "             'provocative': 762,\n",
       "             'pure': 763,\n",
       "             'rest': 764,\n",
       "             'result': 765,\n",
       "             'run': 766,\n",
       "             'single': 767,\n",
       "             'stylish': 768,\n",
       "             'sustain': 769,\n",
       "             'taste': 770,\n",
       "             'touch': 771,\n",
       "             'use': 772,\n",
       "             'utterly': 773,\n",
       "             'watchable': 774,\n",
       "             'weird': 775,\n",
       "             'winning': 776,\n",
       "             'worse': 777,\n",
       "             '2002': 778,\n",
       "             '90': 779,\n",
       "             'America': 780,\n",
       "             'Home': 781,\n",
       "             'Steven': 782,\n",
       "             'apart': 783,\n",
       "             'boy': 784,\n",
       "             'call': 785,\n",
       "             'cheap': 786,\n",
       "             'coming': 787,\n",
       "             'contemporary': 788,\n",
       "             'convincing': 789,\n",
       "             'decent': 790,\n",
       "             'definitely': 791,\n",
       "             'doubt': 792,\n",
       "             'exactly': 793,\n",
       "             'expect': 794,\n",
       "             'eye': 795,\n",
       "             'hand': 796,\n",
       "             'hope': 797,\n",
       "             'important': 798,\n",
       "             'insight': 799,\n",
       "             'liked': 800,\n",
       "             'major': 801,\n",
       "             'masterpiece': 802,\n",
       "             'memory': 803,\n",
       "             'nature': 804,\n",
       "             'needs': 805,\n",
       "             'none': 806,\n",
       "             'novel': 807,\n",
       "             'ones': 808,\n",
       "             'pace': 809,\n",
       "             'pacing': 810,\n",
       "             'perfectly': 811,\n",
       "             'playing': 812,\n",
       "             'previous': 813,\n",
       "             'process': 814,\n",
       "             'puts': 815,\n",
       "             'sensitive': 816,\n",
       "             'sets': 817,\n",
       "             'several': 818,\n",
       "             'situations': 819,\n",
       "             'somewhat': 820,\n",
       "             'start': 821,\n",
       "             'subtle': 822,\n",
       "             'success': 823,\n",
       "             'throughout': 824,\n",
       "             'unexpected': 825,\n",
       "             'wants': 826,\n",
       "             'warm': 827,\n",
       "             'warmth': 828,\n",
       "             'waste': 829,\n",
       "             'Crush': 830,\n",
       "             'George': 831,\n",
       "             'Moore': 832,\n",
       "             'Movie': 833,\n",
       "             'Ms.': 834,\n",
       "             'Much': 835,\n",
       "             'Murphy': 836,\n",
       "             'Spy': 837,\n",
       "             'ability': 838,\n",
       "             'absolutely': 839,\n",
       "             'ago': 840,\n",
       "             'among': 841,\n",
       "             'begins': 842,\n",
       "             'brings': 843,\n",
       "             'called': 844,\n",
       "             'career': 845,\n",
       "             'conventional': 846,\n",
       "             'cool': 847,\n",
       "             'couple': 848,\n",
       "             'create': 849,\n",
       "             'creates': 850,\n",
       "             'cultural': 851,\n",
       "             'deserves': 852,\n",
       "             'dry': 853,\n",
       "             'earnest': 854,\n",
       "             'episode': 855,\n",
       "             'fare': 856,\n",
       "             'form': 857,\n",
       "             'frequently': 858,\n",
       "             'generic': 859,\n",
       "             'girl': 860,\n",
       "             'hold': 861,\n",
       "             'huge': 862,\n",
       "             'insightful': 863,\n",
       "             'kid': 864,\n",
       "             'leads': 865,\n",
       "             'lots': 866,\n",
       "             'loud': 867,\n",
       "             'middle': 868,\n",
       "             'mix': 869,\n",
       "             'problems': 870,\n",
       "             'quiet': 871,\n",
       "             'saw': 872,\n",
       "             'school': 873,\n",
       "             'sexy': 874,\n",
       "             'slightly': 875,\n",
       "             'sounds': 876,\n",
       "             'strange': 877,\n",
       "             'successful': 878,\n",
       "             'supposed': 879,\n",
       "             'surprises': 880,\n",
       "             'thanks': 881,\n",
       "             'treat': 882,\n",
       "             'tribute': 883,\n",
       "             'truth': 884,\n",
       "             'unfunny': 885,\n",
       "             'viewing': 886,\n",
       "             'witty': 887,\n",
       "             'word': 888,\n",
       "             'working': 889,\n",
       "             'worthy': 890,\n",
       "             'writer': 891,\n",
       "             'writer-director': 892,\n",
       "             'British': 893,\n",
       "             'Girl': 894,\n",
       "             'Never': 895,\n",
       "             'Oscar': 896,\n",
       "             'Sandler': 897,\n",
       "             'Solondz': 898,\n",
       "             'Watching': 899,\n",
       "             'Williams': 900,\n",
       "             'actress': 901,\n",
       "             'allows': 902,\n",
       "             'battle': 903,\n",
       "             'beauty': 904,\n",
       "             'believable': 905,\n",
       "             'clearly': 906,\n",
       "             'colorful': 907,\n",
       "             'company': 908,\n",
       "             'complete': 909,\n",
       "             'considerable': 910,\n",
       "             'core': 911,\n",
       "             'created': 912,\n",
       "             'depressing': 913,\n",
       "             'derivative': 914,\n",
       "             'drag': 915,\n",
       "             'edge': 916,\n",
       "             'effect': 917,\n",
       "             'endearing': 918,\n",
       "             'equally': 919,\n",
       "             'except': 920,\n",
       "             'exciting': 921,\n",
       "             'figure': 922,\n",
       "             'finds': 923,\n",
       "             'flawed': 924,\n",
       "             'flicks': 925,\n",
       "             'fully': 926,\n",
       "             'funnier': 927,\n",
       "             'guys': 928,\n",
       "             'hardly': 929,\n",
       "             'interested': 930,\n",
       "             'inventive': 931,\n",
       "             'lets': 932,\n",
       "             'living': 933,\n",
       "             'lovely': 934,\n",
       "             'manipulative': 935,\n",
       "             'manner': 936,\n",
       "             'match': 937,\n",
       "             'merely': 938,\n",
       "             'monster': 939,\n",
       "             'name': 940,\n",
       "             'others': 941,\n",
       "             'overcome': 942,\n",
       "             'painful': 943,\n",
       "             'painfully': 944,\n",
       "             'places': 945,\n",
       "             'poetry': 946,\n",
       "             'pointless': 947,\n",
       "             'provides': 948,\n",
       "             'quickly': 949,\n",
       "             'recommend': 950,\n",
       "             'reveals': 951,\n",
       "             'runs': 952,\n",
       "             'second': 953,\n",
       "             'seriously': 954,\n",
       "             'sexual': 955,\n",
       "             'sitting': 956,\n",
       "             'situation': 957,\n",
       "             'slight': 958,\n",
       "             'sophisticated': 959,\n",
       "             'spectacle': 960,\n",
       "             'stand': 961,\n",
       "             'starts': 962,\n",
       "             'succeeds': 963,\n",
       "             'taking': 964,\n",
       "             'target': 965,\n",
       "             'terms': 966,\n",
       "             'thinking': 967,\n",
       "             'toward': 968,\n",
       "             'treatment': 969,\n",
       "             'twists': 970,\n",
       "             'uneven': 971,\n",
       "             'unique': 972,\n",
       "             'upon': 973,\n",
       "             'usually': 974,\n",
       "             'vivid': 975,\n",
       "             'Both': 976,\n",
       "             'Brown': 977,\n",
       "             'Bullock': 978,\n",
       "             'Does': 979,\n",
       "             'From': 980,\n",
       "             'How': 981,\n",
       "             'Instead': 982,\n",
       "             'Kids': 983,\n",
       "             'King': 984,\n",
       "             'Soderbergh': 985,\n",
       "             'Those': 986,\n",
       "             'World': 987,\n",
       "             'ages': 988,\n",
       "             'attempts': 989,\n",
       "             'based': 990,\n",
       "             'behind': 991,\n",
       "             'car': 992,\n",
       "             'chemistry': 993,\n",
       "             'cinematography': 994,\n",
       "             'class': 995,\n",
       "             'college': 996,\n",
       "             'conclusion': 997,\n",
       "             'credit': 998,\n",
       "             'cut': 999,\n",
       "             ...})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text.vocab.stoi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CBoW(nn.Module):\n",
    "    def __init__(self, input_size, num_classes, batch_size):\n",
    "        super(CBoW, self).__init__()\n",
    "        self.embeddings = nn.Embedding(text.vocab.vectors.size()[0], text.vocab.vectors.size()[1])\n",
    "        self.embeddings.weight.data.copy_(text.vocab.vectors)\n",
    "        self.linear = nn.Linear(input_size+1, num_classes, bias = True)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x, lengths = x\n",
    "        lengths = Variable(lengths.view(-1, 1).float())\n",
    "        embedded = self.embeddings(x)\n",
    "        average_embed = embedded.mean(0)\n",
    "        concat = torch.cat([average_embed, lengths], dim = 1) # add lengths as a feature\n",
    "        output = self.linear(concat)\n",
    "        logits = torch.nn.functional.log_softmax(output, dim = 1)\n",
    "        return logits\n",
    "\n",
    "    def predict(self, x):\n",
    "        logits = self.forward(x)\n",
    "        return logits.max(1)[1] + 1\n",
    "    \n",
    "    def train(self, train_iter, val_iter, test_iter, num_epochs, learning_rate = 1e-3):\n",
    "        criterion = torch.nn.NLLLoss()\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)\n",
    "        loss_vec = []\n",
    "        \n",
    "        for epoch in tqdm_notebook(range(1, num_epochs + 1)):\n",
    "            epoch_loss = 0\n",
    "            for batch in train_iter:\n",
    "                x = batch.text\n",
    "                y = batch.label\n",
    "                \n",
    "                optimizer.zero_grad()\n",
    "                \n",
    "                y_p = self.forward(x)\n",
    "                \n",
    "                loss = criterion(y_p, y-1)\n",
    "                loss.backward()\n",
    "                \n",
    "                optimizer.step()\n",
    "                epoch_loss += loss.data[0]\n",
    "                \n",
    "            self.model = model\n",
    "            \n",
    "            loss_vec.append(epoch_loss / len(train_iter))\n",
    "            if epoch % 1 == 0:\n",
    "                acc = self.validate(val_iter)\n",
    "                print('Epoch {} loss: {} | acc: {}'.format(epoch, loss_vec[epoch-1], acc))\n",
    "                self.model = model\n",
    "                self.test(test_iter)\n",
    "                \n",
    "        plt.plot(range(len(loss_vec)), loss_vec)\n",
    "        plt.xlabel('Epoch')\n",
    "        plt.ylabel('Loss')\n",
    "        plt.show()\n",
    "        print('\\nModel trained.\\n')\n",
    "        self.loss_vec = loss_vec\n",
    "        self.model = model\n",
    "\n",
    "    def test(self, test_iter):\n",
    "        \"All models should be able to be run with following command.\"\n",
    "        upload, trues = [], []\n",
    "        # Update: for kaggle the bucket iterator needs to have batch_size 10\n",
    "        for batch in test_iter:\n",
    "            # Your prediction data here (don't cheat!)\n",
    "            x, y = batch.text, batch.label\n",
    "            preds = self.predict(x)\n",
    "            upload += list(preds.data.numpy())\n",
    "            trues += list(y.data.numpy())\n",
    "            \n",
    "        correct = sum([1 if i == j else 0 for i, j in zip(upload, trues)])\n",
    "        accuracy = correct / len(trues)\n",
    "        print('Test Accuracy:', accuracy)\n",
    "\n",
    "        with open(\"predictions.txt\", \"w\") as f:\n",
    "            for u in upload:\n",
    "                f.write(str(u) + \"\\n\")\n",
    "                \n",
    "    def validate(self, val_iter):\n",
    "        y_p, y_t, correct = [], [], 0\n",
    "        for batch in val_iter:\n",
    "            x, y = batch.text, batch.label\n",
    "            probs = self.model.predict(x)[:len(y.data.numpy())]\n",
    "            y_p += list(probs.data.numpy())\n",
    "            y_t += list(y.data.numpy())\n",
    "        correct = sum([1 if i == j else 0 for i, j in zip(y_p, y_t)])\n",
    "        accuracy = correct / len(y_p)\n",
    "        return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CBoW(input_size = 300, num_classes = 2, batch_size = 10)\n",
    "model.train(train_iter = train_iter, val_iter = val_iter, test_iter = test_iter, num_epochs = 15, learning_rate = 1e-4, plot = False)\n",
    "model.test(test_iter)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
